import os
from functools import wraps
from tempfile import TemporaryDirectory

import pandas as pd
from pandas.tseries.offsets import Week


# Accepted frequency names (noun and adjective forms) -> pandas offset / offset alias.
# The annual entries are templates: '{}' is filled with a value of MONTHS (e.g. 'A-JUN').
FREQS = dict(
    day='D',
    daily='D',
    week=Week(1),
    weekly=Week(1),
    month='M',
    monthly='M',
    quarter='Q',
    quarterly='Q',
    year='A-{}',
    annual='A-{}',
)
# Month number (1-12) -> upper-case three-letter abbreviation, used to anchor annual aliases.
MONTHS = dict(enumerate(['JAN', 'FEB', 'MAR', 'APR', 'MAY', 'JUN',
                         'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC'], start=1))


def save_stat(stat, filename, formatting='%.1f', output='console', directory='cat_rfs/output/stats'):
    """
    Save a statistic to a .txt output file (if output=='paper') or print it to console (if output=='console').

    Parameters
    ----------
    stat: float
        Used statistic.
    filename: string
        Name of the file where statistic is saved (without file extension). Also used as the label
        when printing to console.
    formatting: string, default '%.1f'
        %-style output format applied to the statistic.
    output: {'console', 'paper'}, default 'console'
        If 'console', print "<filename> : <formatted stat>" to console.
        If 'paper', write the formatted statistic (and nothing else) to <directory>/<filename>.txt,
        overwriting any existing file.
        Any other value does nothing.
    directory: string, default 'cat_rfs/output/stats'
        Directory that receives the .txt file when output == 'paper'. A relative path is resolved
        against the current working directory; the directory must already exist.
    """
    if output == 'paper':
        with open(f'{directory}/{filename}.txt', 'w', encoding='utf-8') as f:
            f.write(formatting % stat)
    elif output == 'console':
        print(filename, ':', formatting % stat)


def find_frequency(series):
    """
    Find the frequency of a time series. If time series is irregularly spaced, return False.

    Candidate frequencies are tried in the order daily, monthly, quarterly, annual, and the code of
    the first one for which ``check_ts_integrity`` succeeds is returned. Weekly is not a candidate,
    so a weekly series yields False.

    Parameters
    ----------
    series: pandas.Series
        A datetime time series.

    Returns
    -------
    String or False
        If time series corresponds with a regularly distributed one, return its string code
        ('D', 'M', 'Q', or 'A-<MON>' such as 'A-JUN'). If not, return False.

    Notes
    -----
    'crsp_tradingday' is only tried when check_ts_integrity accepts it (i.e. when it is a key of the
    module-level FREQS). Otherwise it is skipped, so an irregular series returns False instead of
    failing check_ts_integrity's freq assertion.
    """
    # Named differently from the module-level FREQS, which check_ts_integrity validates against.
    candidates = {'daily': 'D',
                  'monthly': 'M',
                  'quarterly': 'Q',
                  'annual': 'A-{}',
                  'crsp_tradingday': 'crsp_tradingday'}

    for freq, code in candidates.items():
        if freq not in FREQS:
            # check_ts_integrity cannot evaluate this frequency; asking would raise, not answer.
            continue
        if check_ts_integrity(series, freq):
            if freq == 'annual':
                # Same anchor month check_ts_integrity used to build its annual benchmark.
                return code.format(MONTHS[series.min().month])
            return code

    return False

def check_ts_integrity(series, freq):
    """
    Checks that a time series is regularly distributed between its min and max dates,
    i.e. that the time series contains no extra observations (e.g. a monthly series with a
    middle-of-month observation), no duplicated observations and no gaps.

    Parameters
    ----------
    series: pandas.Series
        A datetime time series.
    freq: {'day', 'daily', 'week', 'weekly', 'month', 'monthly', 'quarter', 'quarterly', 'year', 'annual'}
        Data frequency. Weekly series can be any weekday (observations exactly 7 days apart).
        Month and quarter observations must be period ends. Year observations must be month ends,
        in the same month every year (the month of the earliest observation).

    Returns
    -------
    Boolean

    Raises
    ------
    AssertionError
        If series is not a pandas.Series or freq is not one of the accepted names.

    Examples
    --------
    assert check_ts_integrity(df['date'], 'monthly')
    """
    # Validate first so that non-Series input gets the intended message, not an AttributeError.
    assert isinstance(series, pd.Series), 'series must be Pandas Series object.'

    assert freq in FREQS.keys(), f"freq must be in {str(list(FREQS.keys()))}."

    start, end = series.min(), series.max()

    if freq in ('year', 'annual'):
        # Annual periods are anchored on the month of the earliest observation, e.g. 'A-JUN'.
        offset = FREQS[freq].format(MONTHS[start.month])
    else:
        # For weekly data, Week(1) has no weekday: it steps 7 days at a time and includes start itself,
        # so the benchmark must start at start (starting a week earlier adds a spurious extra date).
        offset = FREQS[freq]

    # Benchmark: every date a perfectly regular series between start and end contains. Both frames
    # share one column name so the merge works whatever series.name is (None, 0, '', ...).
    key = 'date'
    benchmark = pd.date_range(start, end, freq=offset).to_frame(index=False, name=key)
    observed = series.to_frame(name=key)

    if benchmark.shape[0] != observed.shape[0]:
        return False

    # With equal lengths, the outer merge keeps exactly that many rows only if every benchmark date
    # appears in the series, i.e. the two sets of dates match one-to-one (no gaps, extras or duplicates).
    return benchmark.shape[0] == benchmark.merge(observed, on=key, how='outer').shape[0]

def keep_perils(df, colname='perils'):
    """
    Drop geographies and keep only peril types in colname of df.

    Each cell is treated as a comma-separated list of entries. Of each entry only the second
    whitespace-separated word (the peril) is kept, e.g. 'US Hurricane, Japan Earthquake' becomes
    'Hurricane, Earthquake'. Duplicate perils are removed, keeping the order of first appearance,
    so the result is the same on every run. Missing cells (None/NaN) and blank cells become ''.
    Blank entries (e.g. from a trailing comma) are ignored. An entry with a single word raises
    IndexError.

    Parameters
    ----------
    df: pandas.DataFrame
        Input data frame. It is not modified; a copy is returned.
    colname: string, default 'perils'
        Name of the column to transform.

    Returns
    -------
    pandas.DataFrame
        Copy of df with colname replaced by the comma-separated peril types.
    """
    df = df.copy()

    def _perils_only(cell):
        # Non-strings (None, NaN, pd.NA) carry no perils.
        if not isinstance(cell, str):
            return ''
        perils = []
        for entry in cell.split(','):
            words = entry.split()
            if not words:
                continue
            perils.append(words[1])
        # dict.fromkeys dedupes while keeping insertion order; set() order changed between runs.
        return ', '.join(dict.fromkeys(perils))

    df[colname] = df[colname].apply(_perils_only)

    return df


def list_iterate(iter_list_name, chunksize, output=None, quiet=False):
    """
    Consider a function that takes a list as one of its inputs and then performs some operation for all elements in
    that list. Using list_iterate as a decorator for that function enables the user to define additional input,
    'chunksize', in which case the input list is chopped into chunks and the function is performed separately for
    each of those chunks.

    Parameters
    ----------
    iter_list_name: string
        Name of the function input that defines the list that is being iterated over. Note that this input must be
        a keyword argument (otherwise the wrapper raises KeyError).
    chunksize: positive integer
        Determines how many items in iter_list_name are estimated at a time. Should target a computer with 128GB of RAM.
    output: {None, 'pd'}, default None
          * If None, simply iterate over the function without utilizing its output. This is generally used in cases
            where function will write its end product e.g. to a sqlite database with append. The wrapper returns
            None.
          * If 'pd', function is assumed to return a pandas data frame (or a tuple containing multiple data frames).
            Write the output of an iteration to a temporary feather file. After performing all the loops, read the
            temporary files and concatenate them into one big data frame that is then returned. For tuple outputs,
            a list holding one concatenated data frame per tuple position is returned.
    quiet: boolean, default False
        If True, don't print the status reports every time a new loop is started.

    Notes
    -----
    If the list fits into a single chunk (or is empty), the function is called once with its original arguments
    and, when output == 'pd', its result is returned as-is (a tuple stays a tuple).

    Raises
    ------
    AssertionError
        If chunksize is not an int.
    TypeError
        If output == 'pd' and a chunk returns something other than a DataFrame or a tuple.

    Examples
    --------
    Consider a function that estimates beta for all firms in its input 'permno_list':
        def beta(permno_list, ...):
            ...
            return df

    But maybe calculating betas for all companies simultaneously takes too much computational resources, so it would
    be better to loop over firms in chunks of 1000 and estimate betas for those chunks. Using list_iterate, this can
    be achieved by simply adjusting the function definition as follows:
        @list_iterate('permno_list', chunksize=1000, output='pd')
        def beta(permno_list=[1, 2], ...):
            ...
            return df
    """
    def real_decorator(function):
        @wraps(function)
        def wrapper(*args, **kwargs):
            assert isinstance(chunksize, int)
            iter_list = kwargs[iter_list_name]
            chunks = [iter_list[k:k + chunksize] for k in range(0, len(iter_list), chunksize)]

            if len(chunks) <= 1:
                # Nothing to split (this also covers an empty list): run the function as if undecorated.
                result = function(*args, **kwargs)
                return result if output == 'pd' else None

            with TemporaryDirectory(prefix='list_iterate_') as tmpdir:

                def batch_path(n, idx=None):
                    # os.path.join keeps the files inside tmpdir on every OS, so they are removed with it.
                    suffix = '' if idx is None else f'_df{idx}'
                    return os.path.join(tmpdir, f'batch{n}{suffix}.feather')

                is_tuple = False
                length = 0
                for n, chunk in enumerate(chunks):
                    if not quiet:
                        print(function.__name__ + f': iterating over {iter_list_name} ({n + 1}/{len(chunks)})')
                    kwargs[iter_list_name] = chunk
                    result = function(*args, **kwargs)
                    if output == 'pd':
                        if isinstance(result, pd.DataFrame):
                            is_tuple = False
                            result.to_feather(batch_path(n))
                        elif isinstance(result, tuple):
                            is_tuple = True
                            length = len(result)
                            for idx, frame in enumerate(result):
                                frame.to_feather(batch_path(n, idx))
                        else:
                            raise TypeError(
                                f'Expected result to be a DataFrame or a tuple, but got {type(result).__name__}.'
                            )
                    # Release this chunk's output before computing the next one to limit peak memory.
                    del result

                if output != 'pd':
                    return None

                # Read back while tmpdir still exists. A single pd.concat replaces repeated
                # DataFrame.append, which was removed in pandas 2.0 and copied the data on every call.
                def gather(idx=None):
                    return pd.concat([pd.read_feather(batch_path(n, idx)) for n in range(len(chunks))],
                                     ignore_index=True)

                if not is_tuple:
                    return gather()
                return [gather(idx) for idx in range(length)]

        return wrapper
    return real_decorator
